{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Matplotlib created a temporary config/cache directory at /tmp/matplotlib-73_p9iw7 because the default path (/home/mikyas/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import cv2\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.image as mpimg \n",
    "import trt_pose.coco\n",
    "import math\n",
    "import os\n",
    "import numpy as np\n",
    "import traitlets\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "with open('preprocess/hand_pose.json', 'r') as f:\n",
    "    hand_pose = json.load(f)\n",
    "\n",
    "topology = trt_pose.coco.coco_category_to_topology(hand_pose)\n",
    "import trt_pose.models\n",
    "\n",
    "num_parts = len(hand_pose['keypoints'])\n",
    "num_links = len(hand_pose['skeleton'])\n",
    "\n",
    "model = trt_pose.models.resnet18_baseline_att(num_parts, 2 * num_links).cuda().eval()\n",
    "import torch\n",
    "\n",
    "\n",
    "WIDTH = 224\n",
    "HEIGHT = 224\n",
    "data = torch.zeros((1, 3, HEIGHT, WIDTH)).cuda()\n",
    "\n",
    "if not os.path.exists('model/hand_pose_resnet18_att_244_244_trt.pth'):\n",
    "    MODEL_WEIGHTS = 'model/hand_pose_resnet18_att_244_244.pth'\n",
    "    model.load_state_dict(torch.load(MODEL_WEIGHTS))\n",
    "    import torch2trt\n",
    "    model_trt = torch2trt.torch2trt(model, [data], fp16_mode=True, max_workspace_size=1<<25)\n",
    "    OPTIMIZED_MODEL = 'model/hand_pose_resnet18_att_244_244_trt.pth'\n",
    "    torch.save(model_trt.state_dict(), OPTIMIZED_MODEL)\n",
    "\n",
    "\n",
    "OPTIMIZED_MODEL = 'model/hand_pose_resnet18_att_244_244_trt.pth'\n",
    "from torch2trt import TRTModule\n",
    "\n",
    "model_trt = TRTModule()\n",
    "model_trt.load_state_dict(torch.load(OPTIMIZED_MODEL))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from trt_pose.draw_objects import DrawObjects\n",
    "from trt_pose.parse_objects import ParseObjects\n",
    "\n",
    "parse_objects = ParseObjects(topology,cmap_threshold=0.15, link_threshold=0.15)\n",
    "draw_objects = DrawObjects(topology)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import torchvision.transforms as transforms\n",
    "import PIL.Image\n",
    "\n",
    "mean = torch.Tensor([0.485, 0.456, 0.406]).cuda()\n",
    "std = torch.Tensor([0.229, 0.224, 0.225]).cuda()\n",
    "device = torch.device('cuda')\n",
    "\n",
    "def preprocess(image):\n",
    "    global device\n",
    "    device = torch.device('cuda')\n",
    "    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
    "    image = PIL.Image.fromarray(image)\n",
    "    image = transforms.functional.to_tensor(image).to(device)\n",
    "    image.sub_(mean[:, None, None]).div_(std[:, None, None])\n",
    "    return image[None, ...]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from preprocessdata import preprocessdata\n",
    "preprocessdata = preprocessdata(topology, num_parts)\n",
    "from gesture_classifier import gesture_classifier\n",
    "gesture_classifier = gesture_classifier()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's define a function that will preprocess the image, which is originally in BGR8 / HWC format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_joints(image, joints):\n",
    "    count = 0\n",
    "    for i in joints:\n",
    "        if i==[0,0]:\n",
    "            count+=1\n",
    "    if count>= 3:\n",
    "        return \n",
    "    for i in joints:\n",
    "        cv2.circle(image, (i[0],i[1]), 2, (0,0,255), 1)\n",
    "    cv2.circle(image, (joints[0][0],joints[0][1]), 2, (255,0,255), 1)\n",
    "    for i in hand_pose['skeleton']:\n",
    "        if joints[i[0]-1][0]==0 or joints[i[1]-1][0] == 0:\n",
    "            break\n",
    "        cv2.line(image, (joints[i[0]-1][0],joints[i[0]-1][1]), (joints[i[1]-1][0],joints[i[1]-1][1]), (0,255,0), 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jetcam.usb_camera import USBCamera\n",
    "from jetcam.csi_camera import CSICamera\n",
    "from jetcam.utils import bgr8_to_jpeg\n",
    "\n",
    "camera = USBCamera(width=WIDTH, height=HEIGHT, capture_fps=30, capture_device=1)\n",
    "#camera = CSICamera(width=WIDTH, height=HEIGHT, capture_fps=30)\n",
    "\n",
    "camera.running = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f0833455fca64ca1be4515b072fd6941",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Image(value=b'', format='jpeg', height='224', width='224')"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import ipywidgets\n",
    "from IPython.display import display\n",
    "\n",
    "\n",
    "image_w = ipywidgets.Image(format='jpeg', width=224, height=224)\n",
    "display(image_w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def execute(change):\n",
    "    image = change['new']\n",
    "    data = preprocess(image)\n",
    "    cmap, paf = model_trt(data)\n",
    "    cmap, paf = cmap.detach().cpu(), paf.detach().cpu()\n",
    "    counts, objects, peaks = parse_objects(cmap, paf)\n",
    "    joints = preprocessdata.joints_inference(image, counts, objects, peaks)\n",
    "    draw_joints(image, joints)\n",
    "    #draw_objects(image, counts, objects, peaks)# try this for multiple hand pose prediction \n",
    "    image_w.value = bgr8_to_jpeg(image[:, ::-1, :])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "execute({'new': camera.value})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "camera.observe(execute, names='value')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "camera.unobserve_all()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#camera.running = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
